{
 "cells": [
  {
   "cell_type": "raw",
   "id": "weird-vintage",
   "metadata": {
    "papermill": {
     "duration": 0.030561,
     "end_time": "2021-04-05T23:17:30.506820",
     "exception": false,
     "start_time": "2021-04-05T23:17:30.476259",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "This is a simple demonstration of Debiased Machine Learning estimator for the Conditional Average Treatment Effect. \n",
    "Goal is to estimate the effect of 401(k) eligibility on net financial assets for each value of income. \n",
    "Data set is the same as in (Chernozhukov, Hansen, 2004). \n",
    "\n",
    "\n",
    "The method is based on the following paper. \n",
    "\n",
    "Title:  Debiased Machine Learning of Conditional Average Treatment Effect and Other Causal Functions\n",
    "\n",
    "Authors: Semenova, Vira and Chernozhukov, Victor. \n",
    "\n",
    "Arxiv version: https://arxiv.org/pdf/1702.06240.pdf\n",
    "\n",
    "Published version with replication code: https://academic.oup.com/ectj/advance-article/doi/10.1093/ectj/utaa027/5899048\n",
    "\n",
    "\n",
    "[1]Victor Chernozhukov and Christian Hansen. The impact of 401(k) participation on the wealth distribution: An instrumental quantile regression analysis. Review of Economics and Statistics, 86(3):735–751, 2004."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ruled-nurse",
   "metadata": {
    "papermill": {
     "duration": 0.026442,
     "end_time": "2021-04-05T23:17:30.551541",
     "exception": false,
     "start_time": "2021-04-05T23:17:30.525099",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "Background\n",
    "\n",
    "The target function is Conditional Average Treatment Effect, defined as \n",
    "\n",
    "$$ g(x)=E [ Y(1) - Y(0) |X=x], $$ \n",
    "\n",
    "where $Y(1)$ and $Y(0)$ are potential outcomes in treated and control group. In our case, $Y(1)$ is the potential Net Financial Assets if a subject is eligible for 401(k), and $Y(0)$ is the potential Net Financial Assets if a subject is ineligible. $X$ is a covariate of interest, in this case, income.\n",
    "$ g(x)$ shows expected effect of eligibility on NET TFA for a subject whose income level is $x$.\n",
    "\n",
    "\n",
    "\n",
    "If eligibility indicator is independent of $Y(1), Y(0)$, given pre-401-k assignment characteristics $Z$, the function can expressed in terms of observed data (as opposed to hypothetical, or potential outcomes). Observed data consists of  realized NET TFA $Y = D Y(1) + (1-D) Y(0)$, eligibility indicator $D$, and covariates $Z$ which includes $X$, income. The expression for $g(x)$ is\n",
    "\n",
    "$$ g(x) = E [ Y (\\eta_0) \\mid X=x], $$\n",
    "where the transformed outcome variable is\n",
    "\n",
    "$$Y (\\eta) = \\dfrac{D}{s(Z)} \\left( Y - \\mu(1,Z) \\right) - \\dfrac{1-D}{1-s(Z)} \\left( Y - \\mu(0,Z) \\right) + \\mu(1,Z) - \\mu(0,Z),$$\n",
    "\n",
    "the probability of eligibility is \n",
    "\n",
    "$$s_0(z) = Pr (D=1 \\mid Z=z),$$ \n",
    "\n",
    "the expected net financial asset given $D =d \\in \\{1,0\\}$ and $Z=z$ is\n",
    "\n",
    "$$ \\mu(d,z) = E[ Y \\mid Z=z, D=d]. $$\n",
    "\n",
    "Our goal is to estimate $g(x)$.\n",
    "\n",
    "\n",
    "In step 1, we estimate the unknown functions $s_0(z),  \\mu(1,z),  \\mu(0,z)$ and plug them into $Y (\\eta)$.\n",
    "\n",
    "\n",
    "In step 2, we approximate the function $g(x)$ by a linear combination of basis functions:\n",
    "\n",
    "$$ g(x) = p(x)' \\beta_0, $$\n",
    "\n",
    "\n",
    "where $p(x)$ is a vector of polynomials or splines and\n",
    "\n",
    "$$ \\beta_0 = (E p(X) p(X))^{-1} E p(X) Y (\\eta_0) $$\n",
    "\n",
    "is the best linear predictor. We report\n",
    "\n",
    "$$\n",
    "\\widehat{g}(x) = p(x)' \\widehat{\\beta},\n",
    "$$\n",
    "\n",
    "where $\\widehat{\\beta}$ is the ordinary least squares estimate of $\\beta_0$ defined on the random sample $(X_i, D_i, Y_i)_{i=1}^N$\n",
    "\n",
    "$$\n",
    "\t\\widehat{\\beta} :=\\left( \\dfrac{1}{N} \\sum_{i=1}^N p(X_i) p(X_i)' \\right)^{-1} \\dfrac{1}{N} \\sum_{i=1}^N  p(X_i)Y_i(\\widehat{\\eta})\n",
    "$$\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "narrative-mother",
   "metadata": {
    "_execution_state": "idle",
    "_uuid": "051d70d956493feee0c6d64651c6a088724dca2a",
    "papermill": {
     "duration": 3.170404,
     "end_time": "2021-04-05T23:17:33.746980",
     "exception": false,
     "start_time": "2021-04-05T23:17:30.576576",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "## load packages\n",
    "rm(list=ls())\n",
    "library(foreign)\n",
    "library(quantreg)\n",
    "library(splines)\n",
    "library(lattice)\n",
    "#library(mnormt);\n",
    "library(Hmisc)\n",
    "library(fda);\n",
    "library(hdm)\n",
    "library(randomForest)\n",
    "library(ranger)\n",
    "library(sandwich)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sublime-brazil",
   "metadata": {
    "papermill": {
     "duration": 0.361163,
     "end_time": "2021-04-05T23:17:34.137259",
     "exception": false,
     "start_time": "2021-04-05T23:17:33.776096",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "## 401k dataset\n",
    "data(pension)\n",
    "pension$net_tfa<-pension$net_tfa/10000\n",
    "## covariate of interest -- log income -- \n",
    "pension$inc = log(pension$inc)\n",
    "#pension$inc[is.na(pension$inc)]<-0\n",
    "pension<-pension[!is.na(pension$inc) & pension$inc!=-Inf & pension$inc !=Inf,]\n",
    "\n",
    "\n",
    "## outcome variable -- total net financial assets\n",
    "Y=pension$net_tfa\n",
    "## binary treatment --  indicator of 401(k) eligibility\n",
    "D=pension$e401\n",
    "\n",
    "\n",
    "X=pension$inc\n",
    "## target parameter is CATE = E[ Y(1) - Y(0) | X] \n",
    "\n",
    "\n",
    "## raw covariates so that Y(1) and Y(0) are independent of D given Z\n",
    "Z = pension[,c(\"age\",\"inc\",\"fsize\",\"educ\",\"male\",\"db\",\"marr\",\"twoearn\",\"pira\",\"hown\",\"hval\",\"hequity\",\"hmort\",\n",
    "              \"nohs\",\"hs\",\"smcol\")]\n",
    "\n",
    "\n",
    "y_name   <- \"net_tfa\";\n",
    "d_name    <- \"e401\";\n",
    "form_z    <- \"(poly(age, 6) + poly(inc, 8) + poly(educ, 4) + poly(fsize,2) + as.factor(marr) + as.factor(twoearn) + as.factor(db) + as.factor(pira) + as.factor(hown))^2\";\n",
    "\n",
    "\n",
    "\n",
    "cat(sprintf(\"\\n sample size is %g \\n\", length(Y) ))\n",
    "cat(sprintf(\"\\n num raw covariates z is %g \\n\", dim(Z)[2] ))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "accessible-genetics",
   "metadata": {
    "papermill": {
     "duration": 0.025694,
     "end_time": "2021-04-05T23:17:34.196413",
     "exception": false,
     "start_time": "2021-04-05T23:17:34.170719",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "In Step 1, we estimate three functions:\n",
    "\n",
    "1. probability of treatment assignment $s_0(z)$ \n",
    "\n",
    "2.-3. regression functions $\\mu_0(1,z)$ and $\\mu_0(0,z)$.  \n",
    "\n",
    "We use the cross-fitting procedure with $K=2$ holds. For definition of cross-fitting with $K$ folds, check the sample splitting in ```DML2.for.PLM``` function defined in https://www.kaggle.com/victorchernozhukov/debiased-ml-for-partially-linear-model-in-r\n",
    "\n",
    "For each function, we try random forest.\n"
   ]
  },
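  {
   "cell_type": "markdown",
   "id": "crossfit-sketch",
   "metadata": {},
   "source": [
    "The cross-fitting idea can be sketched on simulated data as follows. This is a minimal illustration, not the procedure used in this notebook: the data-generating process, the choice $K=5$, and the simple `glm`/`lm` learners are all assumptions made for the example. Each nuisance function is fit on $K-1$ folds and evaluated on the held-out fold, and the average of the resulting signal should be close to the true average effect (set to 2 here).\n",
    "\n",
    "```r\n",
    "# Simulated data (hypothetical): the true treatment effect is 2 for everyone\n",
    "set.seed(123)\n",
    "n <- 2000\n",
    "z <- rnorm(n)\n",
    "d <- rbinom(n, 1, plogis(0.5 * z))\n",
    "y <- 2 * d + z + rnorm(n)\n",
    "dat <- data.frame(y = y, d = d, z = z)\n",
    "\n",
    "# Randomly assign each observation to one of K folds\n",
    "K <- 5\n",
    "fold <- sample(rep(1:K, length.out = n))\n",
    "\n",
    "s.hat <- mu1.hat <- mu0.hat <- rep(NA_real_, n)\n",
    "for (k in 1:K) {\n",
    "  train <- dat[fold != k, ]\n",
    "  test  <- dat[fold == k, ]\n",
    "  # Nuisance models are fit on the other K-1 folds only\n",
    "  fit.s  <- glm(d ~ z, family = binomial, data = train)\n",
    "  fit.mu <- lm(y ~ d + z, data = train)\n",
    "  # ... and evaluated on the held-out fold\n",
    "  s.hat[fold == k]   <- predict(fit.s, newdata = test, type = \"response\")\n",
    "  mu1.hat[fold == k] <- predict(fit.mu, newdata = transform(test, d = 1))\n",
    "  mu0.hat[fold == k] <- predict(fit.mu, newdata = transform(test, d = 0))\n",
    "}\n",
    "\n",
    "# Orthogonal signal, same formula as Y(eta) in the Background section\n",
    "signal <- (y - mu1.hat) * d / s.hat - (y - mu0.hat) * (1 - d) / (1 - s.hat) + mu1.hat - mu0.hat\n",
    "ate.hat <- mean(signal)\n",
    "ate.hat\n",
    "```\n",
    "\n",
    "With $K=2$ the loop reduces to the two-half split used in the first-stage functions of this notebook."
   ]
  },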
  {
   "cell_type": "markdown",
   "id": "familiar-pierce",
   "metadata": {
    "papermill": {
     "duration": 0.026737,
     "end_time": "2021-04-05T23:17:34.248852",
     "exception": false,
     "start_time": "2021-04-05T23:17:34.222115",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "First Stage: estimate $\\mu_0(1,z)$ and $\\mu_0(0,z)$ and $s_0(z)$ by lasso"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "hydraulic-housing",
   "metadata": {
    "papermill": {
     "duration": 0.053931,
     "end_time": "2021-04-05T23:17:34.328657",
     "exception": false,
     "start_time": "2021-04-05T23:17:34.274726",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "first_stage_lasso<-function(data,d_name,y_name, form_z, seed=1) {\n",
    "\n",
    "  # Sample size\n",
    "  N<-dim(data)[1]\n",
    "  # Estimated regression function in control group\n",
    "  mu0.hat<-rep(1,N) \n",
    "  # Estimated regression function in treated group\n",
    "  mu1.hat<-rep(1,N) \n",
    "  # Propensity score\n",
    "  s.hat<-rep(1,N)\n",
    "  seed=1\n",
    "  ## define sample splitting\n",
    "  set.seed(seed)\n",
    "  inds.train=sample(1:N,floor(N/2))\n",
    "  inds.eval=setdiff(1:N,inds.train)\n",
    "\n",
    "  print (\"Estimate treatment probability, first half\")\n",
    "  ## conditional probability of 401 k eligibility (i.e., propensity score) based on random forest\n",
    "  fitted.lasso.pscore<-rlassologit(as.formula(paste0(d_name,\"~\",form_z )),data=data[inds.train,])\n",
    "\n",
    "  s.hat[inds.eval]<-predict(fitted.lasso.pscore,data[inds.eval,],type=\"response\")\n",
    "  print (\"Estimate treatment probability, second half\")\n",
    "  fitted.lasso.pscore<-rlassologit(as.formula(paste0(d_name,\"~\",form_z )),data=data[inds.eval,])\n",
    "  s.hat[inds.train]<-predict( fitted.lasso.pscore,data[inds.train,],type=\"response\")\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "  data1<-data\n",
    "  data1[,d_name]<-1\n",
    "\n",
    "  data0<-data\n",
    "  data0[,d_name]<-0\n",
    "\n",
    "  print (\"Estimate expectation function, first half\") \n",
    "  fitted.lasso.mu<-rlasso(as.formula(paste0(y_name,\"~\",d_name,\"+(\",form_z,\")\" )),data=data[inds.train,])\n",
    "  mu1.hat[inds.eval]<-predict( fitted.lasso.mu,data1[inds.eval,])\n",
    "  mu0.hat[inds.eval]<-predict( fitted.lasso.mu,data0[inds.eval,]) \n",
    "  \n",
    "  print (\"Estimate expectation function, second half\") \n",
    "  fitted.lasso.mu<-rlasso(as.formula(paste0(y_name,\"~\",d_name,\"+(\",form_z,\")\" )),data=data[inds.eval,])\n",
    "  mu1.hat[inds.train]<-predict( fitted.lasso.mu,data1[inds.train,])\n",
    "  mu0.hat[inds.train]<-predict( fitted.lasso.mu,data0[inds.train,]) \n",
    " \n",
    "  return (list(mu1.hat=mu1.hat,\n",
    "              mu0.hat=mu0.hat,\n",
    "              s.hat=s.hat))\n",
    "    \n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "synthetic-collar",
   "metadata": {
    "papermill": {
     "duration": 0.023914,
     "end_time": "2021-04-05T23:17:34.379357",
     "exception": false,
     "start_time": "2021-04-05T23:17:34.355443",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "First Stage: estimate $\\mu_0(1,z)$ and $\\mu_0(0,z)$ and $s_0(z)$ by random forest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "boolean-fighter",
   "metadata": {
    "papermill": {
     "duration": 0.044367,
     "end_time": "2021-04-05T23:17:34.447791",
     "exception": false,
     "start_time": "2021-04-05T23:17:34.403424",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "first_stage_rf<-function(Y,D,X,Z,seed=1) {\n",
    "\n",
    "  # Sample size\n",
    "  N<-length(D)\n",
    "  # Estimated regression function in control group\n",
    "  mu0.hat<-rep(1,N) \n",
    "  # Estimated regression function in treated group\n",
    "  mu1.hat<-rep(1,N) \n",
    "  # Propensity score\n",
    "  s.hat<-rep(1,N)\n",
    "    \n",
    "    \n",
    "  ## define sample splitting\n",
    "   set.seed(seed)\n",
    "  inds.train=sample(1:N,floor(N/2))\n",
    "  inds.eval=setdiff(1:N,inds.train)\n",
    "\n",
    "    print (\"Estimate treatment probability, first half\")\n",
    "  ## conditional probability of 401 k eligibility (i.e., propensity score) based on random forest\n",
    "    D.f<-as.factor(as.character(D))\n",
    "   fitted.rf.pscore<-randomForest(Z,D.f,subset=inds.train)\n",
    "   s.hat[inds.eval]<-predict(fitted.rf.pscore,Z[inds.eval,],type=\"prob\")[,2]\n",
    "    print (\"Estimate treatment probability, second half\")\n",
    "   fitted.rf<-randomForest(Z,D.f,subset=inds.eval) \n",
    "    s.hat[inds.train]<-predict(fitted.rf.pscore,Z[inds.train,],type=\"prob\")[,2]\n",
    "    \n",
    "  ## conditional expected net financial assets (i.e.,  regression function) based on random forest\n",
    "  \n",
    "  covariates<-cbind(Z,D)\n",
    "    \n",
    "  covariates1<-cbind(Z,D=rep(1,N))\n",
    "  covariates0<-cbind(Z,D=rep(0,N))\n",
    "  \n",
    "  print (\"Estimate expectation function, first half\") \n",
    "  fitted.rf.mu<-randomForest(cbind(Z,D),Y,subset=inds.train)\n",
    "  mu1.hat[inds.eval]<-predict( fitted.rf.mu,covariates1[inds.eval,])\n",
    "  mu0.hat[inds.eval]<-predict( fitted.rf.mu,covariates0[inds.eval,]) \n",
    "  \n",
    "  print (\"Estimate expectation function, second half\") \n",
    "   fitted.rf.mu<-randomForest(cbind(Z,D),Y,subset=inds.eval) \n",
    "   mu1.hat[inds.train]<-predict( fitted.rf.mu,covariates1[inds.train,])\n",
    "   mu0.hat[inds.train]<-predict( fitted.rf.mu,covariates0[inds.train,]) \n",
    " \n",
    "  return (list(mu1.hat=mu1.hat,\n",
    "              mu0.hat=mu0.hat,\n",
    "              s.hat=s.hat))\n",
    "    \n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "common-poverty",
   "metadata": {
    "papermill": {
     "duration": 0.023746,
     "end_time": "2021-04-05T23:17:34.496537",
     "exception": false,
     "start_time": "2021-04-05T23:17:34.472791",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "In Step 2, we approximate $Y(\\eta_0)$ by a vector of basis functions. There are two use cases:\n",
    "****\n",
    "2.A. Group Average Treatment Effect, described above\n",
    "\n",
    "\n",
    "2.B. Average Treatment Effect conditional on income value. There are three smoothing options:\n",
    "\n",
    "1. splines offered in ```least_squares_splines```\n",
    "\n",
    "2. orthogonal polynomials with the highest degree chosen by cross-validation ```least_squares_series```\n",
    "\n",
    "3. standard polynomials with the highest degree input by user ```least_squares_series_old```\n",
    "\n",
    "\n",
    "The default option is option 3."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "connected-somerset",
   "metadata": {
    "papermill": {
     "duration": 0.024285,
     "end_time": "2021-04-05T23:17:34.545072",
     "exception": false,
     "start_time": "2021-04-05T23:17:34.520787",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "2.A. The simplest use case of Conditional Average Treatment Effect is GATE, or Group Average Treatment Effect. Partition the support of income as\n",
    "\n",
    "$$ - \\infty = \\ell_0 < \\ell_1 < \\ell_2 \\dots \\ell_K = \\infty $$\n",
    "\n",
    "define intervals $I_k = [ \\ell_{k-1}, \\ell_{k})$. Let $X$ be income covariate. For $X$, define a group indicator \n",
    "\n",
    "$$ G_k(X) = 1[X \\in I_k], $$\n",
    "\n",
    "and the vector of basis functions \n",
    "\n",
    "$$ p(X) = (G_1(X), G_2(X), \\dots, G_K(X)) $$\n",
    "\n",
    "Then, the Best Linear Predictor $\\beta_0$ vector shows the average treatment effect for each group."
   ]
  },
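  {
   "cell_type": "markdown",
   "id": "gate-sketch",
   "metadata": {},
   "source": [
    "As a small check of this claim, the sketch below regresses a signal on quintile-group indicators without an intercept and compares the coefficients with the group means of the signal. The data are simulated and all variable names are hypothetical; in the notebook the signal is the orthogonal signal built from the first stage.\n",
    "\n",
    "```r\n",
    "# Simulated covariate and signal (hypothetical stand-ins for income and Y(eta))\n",
    "set.seed(1)\n",
    "n <- 1000\n",
    "x <- runif(n)\n",
    "signal <- 1 + 2 * x + rnorm(n)\n",
    "\n",
    "# Quintile groups [l_{k-1}, l_k); include.lowest closes the last interval on the right\n",
    "breaks <- quantile(x, probs = (0:5) / 5)\n",
    "g <- cut(x, breaks = breaks, right = FALSE, include.lowest = TRUE)\n",
    "\n",
    "# Regression on the full set of group dummies, no intercept\n",
    "fit <- lm(signal ~ g - 1)\n",
    "beta.hat <- unname(coef(fit))\n",
    "\n",
    "# Group-by-group averages of the signal\n",
    "group.means <- as.numeric(tapply(signal, g, mean))\n",
    "\n",
    "# The two coincide up to numerical error\n",
    "max(abs(beta.hat - group.means))\n",
    "```\n",
    "\n",
    "Because the group indicators are mutually exclusive, each least squares coefficient is simply the sample mean of the signal within its group."
   ]
  },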
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "japanese-bishop",
   "metadata": {
    "papermill": {
     "duration": 168.003121,
     "end_time": "2021-04-05T23:20:22.572302",
     "exception": false,
     "start_time": "2021-04-05T23:17:34.569181",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "## estimate first stage functions by random forest\n",
    "## may take a while\n",
    "fs.hat.rf = first_stage_rf(Y,D,X,Z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "graphic-parameter",
   "metadata": {
    "papermill": {
     "duration": 0.063616,
     "end_time": "2021-04-05T23:20:22.660625",
     "exception": false,
     "start_time": "2021-04-05T23:20:22.597009",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "X=pension$inc\n",
    "fs.hat<-fs.hat.rf\n",
    "min_cutoff=0.01\n",
    "# regression function\n",
    "mu1.hat<-fs.hat[[\"mu1.hat\"]]\n",
    "mu0.hat<-fs.hat[[\"mu0.hat\"]]\n",
    "# propensity score\n",
    "s.hat<-fs.hat[[\"s.hat\"]]\n",
    "s.hat<-sapply(s.hat,max,min_cutoff)\n",
    "### Construct Orthogonal Signal \n",
    "\n",
    "\n",
    "RobustSignal<-(Y - mu1.hat)*D/s.hat - (Y - mu0.hat)*(1-D)/(1-s.hat) + mu1.hat - mu0.hat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "numerical-niger",
   "metadata": {
    "papermill": {
     "duration": 0.046156,
     "end_time": "2021-04-05T23:20:22.731233",
     "exception": false,
     "start_time": "2021-04-05T23:20:22.685077",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "qtmax <- function(C, S=10000, alpha)\n",
    "  {;\n",
    "   p <- nrow(C);\n",
    "   tmaxs <- apply(abs(matrix(rnorm(p*S), nrow = p, ncol = S)), 2, max);\n",
    "   return(quantile(tmaxs, 1-alpha));\n",
    "  };\n",
    "\n",
    "# This function computes the square root of a symmetric matrix using the spectral decomposition;\n",
    "\n",
    "\n",
    "group_average_treatment_effect<-function(X,Y,max_grid=5,alpha=0.05, B=10000) {\n",
    "    \n",
    " grid<-quantile(X,probs=c((0:max_grid)/max_grid))\n",
    " X.raw<-matrix(NA, nrow=length(Y),ncol=length(grid)-1) \n",
    "    \n",
    " for (k in 2:((length(grid)))) {\n",
    "       X.raw[,k-1]<-sapply(X, function (x) ifelse (x>=grid[k-1] & x<grid[k],1,0) )\n",
    " }\n",
    " k=length(grid)\n",
    " X.raw[,k-1]<-sapply(X, function (x) ifelse (x>=grid[k-1] & x<=grid[k],1,0) )\n",
    "                                  \n",
    " ols.fit<- lm(Y~X.raw-1)\n",
    " coefs   <- coef(ols.fit)\n",
    " vars <- names(coefs)\n",
    " HCV.coefs <- vcovHC(ols.fit, type = 'HC')\n",
    " coefs.se <- sqrt(diag(HCV.coefs)) # White std errors\n",
    " ## this is an identity matrix\n",
    "                     ## qtmax is simplified\n",
    " C.coefs  <- (diag(1/sqrt(diag(HCV.coefs)))) %*% HCV.coefs %*% (diag(1/sqrt(diag(HCV.coefs))));                    \n",
    "                 \n",
    "                     \n",
    " tes  <- coefs\n",
    " tes.se <- coefs.se\n",
    " tes.cor <- C.coefs\n",
    " crit.val <- qtmax(tes.cor,B,alpha);\n",
    "                     \n",
    " tes.ucb  <- tes + crit.val * tes.se;\n",
    " tes.lcb  <- tes - crit.val * tes.se;\n",
    "\n",
    " tes.uci  <- tes + qnorm(1-alpha/2) * tes.se;\n",
    " tes.lci  <- tes + qnorm(alpha/2) * tes.se;\n",
    "\n",
    "                     \n",
    " return(list(beta.hat=coefs, ghat.lower.point=tes.lci, ghat.upper.point=tes.uci,\n",
    "           ghat.lower=tes.lcb, ghat.upper= tes.ucb, crit.val=crit.val ))\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "numeric-prime",
   "metadata": {
    "papermill": {
     "duration": 0.795965,
     "end_time": "2021-04-05T23:20:23.551724",
     "exception": false,
     "start_time": "2021-04-05T23:20:22.755759",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "res<-group_average_treatment_effect(X=X,Y=RobustSignal)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "oriented-license",
   "metadata": {
    "papermill": {
     "duration": 0.438302,
     "end_time": "2021-04-05T23:20:24.015967",
     "exception": false,
     "start_time": "2021-04-05T23:20:23.577665",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "## this code is taken from L1 14.382 taught at MIT\n",
    "## author: Mert Demirer\n",
    "options(repr.plot.width=10, repr.plot.height=8)\n",
    "\n",
    "tes<-res$beta.hat\n",
    "tes.lci<-res$ghat.lower.point\n",
    "tes.uci<-res$ghat.upper.point\n",
    "\n",
    "tes.lcb<-res$ghat.lower\n",
    "tes.ucb<-res$ghat.upper\n",
    "tes.lev<-c('0%-20%', '20%-40%','40%-60%','60%-80%','80%-100%')\n",
    "\n",
    "plot( c(1,5), las = 2, xlim =c(0.6, 5.4), ylim = c(.05, 2.09),  type=\"n\",xlab=\"Income group\", \n",
    "     ylab=\"Average Effect on NET TFA (per 10 K)\", main=\"Group Average Treatment Effects on NET TFA\", xaxt=\"n\");\n",
    "axis(1, at=1:5, labels=tes.lev);\n",
    "for (i in 1:5)\n",
    "{;\n",
    " rect(i-0.2, tes.lci[i], i+0.2,  tes.uci[i], col = NA,  border = \"red\", lwd = 3);    \n",
    " rect(i-0.2, tes.lcb[i], i+0.2, tes.ucb[i], col = NA,  border = 4, lwd = 3 );   \n",
    " segments(i-0.2, tes[i], i+0.2, tes[i], lwd = 5 );\n",
    "};\n",
    "abline(h=0);\n",
    "\n",
    "legend(2.5, 2.0, c('Regression Estimate', '95% Simultaneous Confidence Interval', '95% Pointwise Confidence Interval'), col = c(1,4,2), lwd = c(4,3,3), horiz = F, bty = 'n', cex=0.8);\n",
    "\n",
    "dev.off()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "promotional-township",
   "metadata": {
    "papermill": {
     "duration": 0.050484,
     "end_time": "2021-04-05T23:20:24.094587",
     "exception": false,
     "start_time": "2021-04-05T23:20:24.044103",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "least_squares_splines<-function(X,Y,max_knot,norder,nderiv,...) {\n",
    "  ## Create technical regressors\n",
    "    cv.bsp<-rep(0,max_knot-1)\n",
    "    for (knot in 2:max_knot) {\n",
    "      breaks<- quantile(X, c(0:knot)/knot)\n",
    "      formula.bsp \t<- Y ~ bsplineS(X, breaks =breaks, norder = norder, nderiv = nderiv)[ ,-1]\n",
    "      fit\t<- lm(formula.bsp);\n",
    "      cv.bsp[knot-1]\t\t<- sum( (fit$res / (1 - hatvalues(fit)) )^2);\n",
    "    }\n",
    "    ## Number of knots chosen by cross-validation\n",
    "    cv_knot<-which.min(cv.bsp)+1\n",
    "     breaks<- quantile(X, c(0:cv_knot)/cv_knot)\n",
    "   formula.bsp \t<- Y ~ bsplineS(X, breaks =breaks, norder = norder, nderiv = 0)[ ,-1]\n",
    "   fit\t<- lm(formula.bsp); \n",
    "   \n",
    "    \n",
    "   return(list(cv_knot=cv_knot,fit=fit))\n",
    "}\n",
    "\n",
    "\n",
    "least_squares_series<-function(X, Y,max_degree,...) {\n",
    " \n",
    "  cv.pol<-rep(0,max_degree)\n",
    "  for (degree in 1:max_degree) {\n",
    "    formula.pol \t<- Y ~ poly(X, degree)\n",
    "    fit\t<- lm(formula.pol );\n",
    "    cv.pol[degree]\t\t<- sum( (fit$res / (1 - hatvalues(fit)) )^2);\n",
    "  }\n",
    "  ## Number of knots chosen by cross-validation\n",
    "  cv_degree<-which.min(cv.pol)\n",
    "  ## Estimate coefficients\n",
    "  formula.pol \t<- Y ~ poly(X, cv_degree)\n",
    "  fit\t<- lm(formula.pol);\n",
    "    \n",
    "    \n",
    "  return(list(fit=fit,cv_degree=cv_degree))\n",
    "}\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "divine-breach",
   "metadata": {
    "papermill": {
     "duration": 0.055202,
     "end_time": "2021-04-05T23:20:24.177456",
     "exception": false,
     "start_time": "2021-04-05T23:20:24.122254",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "msqrt <- function(C)\n",
    "  {;\n",
    "  C.eig <- eigen(C);\n",
    "  return(C.eig$vectors %*% diag(sqrt(C.eig$values)) %*% solve(C.eig$vectors));\n",
    "  };\n",
    "\n",
    "\n",
    "tboot<-function(regressors_grid, Omega.hat ,alpha, B=10000) {\n",
    "  \n",
    "    \n",
    "   numerator_grid<-regressors_grid%*%msqrt( Omega.hat)\n",
    "   denominator_grid<-sqrt(diag(regressors_grid%*% Omega.hat%*%t(regressors_grid)))\n",
    "    \n",
    "   norm_numerator_grid<-numerator_grid\n",
    "   for (k in 1:dim(numerator_grid)[1]) {\n",
    "       norm_numerator_grid[k,]<-numerator_grid[k,]/denominator_grid[k]\n",
    "   }\n",
    "   \n",
    "   tmaxs <- apply(abs( norm_numerator_grid%*% matrix(rnorm(dim(numerator_grid)[2]*B), nrow = dim(numerator_grid)[2], ncol = B)), 2, max)\n",
    "   return(quantile(tmaxs, 1-alpha))\n",
    " \n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "alleged-board",
   "metadata": {
    "papermill": {
     "duration": 0.048834,
     "end_time": "2021-04-05T23:20:24.256326",
     "exception": false,
     "start_time": "2021-04-05T23:20:24.207492",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "second_stage<-function(fs.hat,Y,D,X,max_degree=3,norder=4,nderiv=0,ss_method=\"poly\",min_cutoff=0.01,alpha=0.05,eps=0.1,...) {\n",
    "  \n",
    "     X_grid = seq(min(X),max(X),eps)\n",
    "     mu1.hat<-fs.hat[[\"mu1.hat\"]]\n",
    "    mu0.hat<-fs.hat[[\"mu0.hat\"]]\n",
    "    # propensity score\n",
    "    s.hat<-fs.hat[[\"s.hat\"]]\n",
    "    s.hat<-sapply(s.hat,max,min_cutoff)\n",
    "    ### Construct Orthogonal Signal \n",
    "\n",
    "    RobustSignal<-(Y - mu1.hat)*D/s.hat - (Y - mu0.hat)*(1-D)/(1-s.hat) + mu1.hat - mu0.hat\n",
    "\n",
    "\n",
    "  \n",
    "   # Estimate the target function using least squares series\n",
    "   if (ss_method == \"ortho_poly\") {\n",
    "    res<-least_squares_series(X=X,Y=RobustSignal,eps=0.1,max_degree=max_degree)\n",
    "    fit<-res$fit\n",
    "    cv_degree<-res$cv_degree\n",
    "    regressors_grid<-cbind( rep(1,length(X_grid)), poly(X_grid,cv_degree))   \n",
    "    \n",
    "  } \n",
    "  if (ss_method == \"splines\") {\n",
    "      \n",
    "    res<-least_squares_splines(X=X,Y=RobustSignal,eps=0.1,max_knot=max_knot,norder=norder,nderiv=nderiv)\n",
    "    fit<-res$fit\n",
    "    cv_knot<-res$cv_knot\n",
    "    breaks<- quantile(X, c(0:cv_knot)/cv_knot)  \n",
    "    regressors_grid<-cbind( rep(1,length(X_grid)), bsplineS(X_grid, breaks =breaks, norder = norder, nderiv = nderiv)[ ,-1])\n",
    "    degree=cv_knot\n",
    " \n",
    "\n",
    "  }\n",
    "\n",
    "\n",
    "  g.hat<-regressors_grid%*%coef(fit)\n",
    " \n",
    "    \n",
    "  HCV.coefs <- vcovHC(fit, type = 'HC') \n",
    "  #Omega.hat<-white_vcov(regressors,Y,b.hat=coef(fit))\n",
    "  standard_error<-sqrt(diag(regressors_grid%*% HCV.coefs%*%t(regressors_grid)))\n",
    "  ### Lower Pointwise CI\n",
    "  ghat.lower.point<-g.hat+qnorm(alpha/2)*standard_error\n",
    "  ### Upper Pointwise CI\n",
    "  ghat.upper.point<-g.hat+qnorm(1-alpha/2)*standard_error\n",
    " \n",
    "   max_tstat<-tboot(regressors_grid=regressors_grid,  Omega.hat=HCV.coefs,alpha=alpha)\n",
    "  \n",
    "    \n",
    "  ## Lower Uniform CI\n",
    "  ghat.lower<-g.hat-max_tstat*standard_error\n",
    "  ## Upper Uniform CI\n",
    "  ghat.upper<-g.hat+max_tstat*standard_error\n",
    "  return(list(ghat.lower=ghat.lower,g.hat=g.hat, ghat.upper=ghat.upper,fit=fit,ghat.lower.point=ghat.lower.point,\n",
    "              ghat.upper.point=ghat.upper.point,X_grid=X_grid,degree=cv_degree))\n",
    "    \n",
    "\n",
    "\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "incoming-image",
   "metadata": {
    "papermill": {
     "duration": 0.051062,
     "end_time": "2021-04-05T23:20:24.335000",
     "exception": false,
     "start_time": "2021-04-05T23:20:24.283938",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "make_plot<-function(res,lowy,highy,degree,ss_method = \"series\",uniform=TRUE,...) {\n",
    "  show(degree)\n",
    "\n",
    "  title=\"Effect of 401(k) on Net TFA\"\n",
    "  X_grid=res$X_grid\n",
    "  len = length(X_grid)\n",
    "  \n",
    "    \n",
    "  if (uniform) {\n",
    "       group <-c(rep(\"UCI\",len), rep(\"PCI\",len), rep(\"Estimate\",len),rep(\"PCIL\",len),rep(\"UCIL\",len))\n",
    "       group_type<- c(rep(\"CI\",len), rep(\"CI\",len), rep(\"Estimate\",len),rep(\"CI\",len),rep(\"CI\",len))\n",
    "       group_ci_type<-c(rep(\"Uniform\",len), rep(\"Point\",len), rep(\"Uniform\",len),rep(\"Point\",len),rep(\"Uniform\",len))\n",
    "\n",
    "     df<-data.frame(income=rep(X_grid,5), outcome = c(res$ghat.lower,res$ghat.lower.point,res$g.hat,res$ghat.upper.point,res$ghat.upper),group=group, group_col = group_type,group_line =group_ci_type )\n",
    "       p<-ggplot(data=df)+\n",
    "    aes(x=income,y=outcome,colour=group )+\n",
    "    theme_bw()+\n",
    "    xlab(\"Income (log scale)\")+\n",
    "    ylab(\"Net TFA, (thousand dollars)\")+\n",
    "    scale_colour_manual(values=c(\"black\",\"blue\",\"blue\",\"blue\",\"blue\"))+\n",
    "    theme(plot.title = element_text(hjust = 0.5),text=element_text(size=20,  family=\"serif\"))+\n",
    "    theme(legend.title=element_blank())+\n",
    "    theme(legend.position=\"none\")+\n",
    "    ylim(low=lowy,high=highy)+\n",
    "    geom_line(aes(linetype = group_line),size=1.5)+\n",
    "    scale_linetype_manual(values=c(\"dashed\",\"solid\"))+\n",
    "    ggtitle(title)  \n",
    "  } \n",
    "\n",
    "  if (!uniform) {\n",
    "      group <-c( rep(\"PCI\",len), rep(\"Estimate\",len),rep(\"PCIL\",len))\n",
    "      group_type<- c(rep(\"CI\",len), rep(\"Estimate\",len),rep(\"CI\",len))\n",
    "     group_ci_type<-c(rep(\"Point\",len), rep(\"Uniform\",len),rep(\"Point\",len))\n",
    "      \n",
    "       df<-data.frame(income=rep(X_grid,3), outcome = c(res$ghat.lower.point,res$g.hat,res$ghat.upper.point),group=group, group_col = group_type,group_line =group_ci_type )\n",
    "   \n",
    "        p<-ggplot(data=df)+\n",
    "    aes(x=income,y=outcome,colour=group )+\n",
    "    theme_bw()+\n",
    "    xlab(\"Income (log scale)\")+\n",
    "    ylab(\"Net TFA, (thousand dollars)\")+\n",
    "    scale_colour_manual(values=c(\"black\",\"blue\",\"blue\",\"blue\",\"blue\"))+\n",
    "    theme(plot.title = element_text(hjust = 0.5),text=element_text(size=20,  family=\"serif\"))+\n",
    "    theme(legend.title=element_blank())+\n",
    "    theme(legend.position=\"none\")+\n",
    "    ylim(low=lowy,high=highy)+\n",
    "    geom_line(aes(linetype = group_line),size=1.5)+\n",
    "    scale_linetype_manual(values=c(\"dashed\",\"solid\"))+\n",
    "    ggtitle(title)   \n",
    "\n",
    "   }\n",
    "\n",
    "\n",
    " \n",
    "  return(p)\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "infinite-delhi",
   "metadata": {
    "papermill": {
     "duration": 0.301006,
     "end_time": "2021-04-05T23:20:24.666001",
     "exception": false,
     "start_time": "2021-04-05T23:20:24.364995",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "### assign outcome = net tfa, treatment = e401 eligibility, covariate x = income\n",
    "res_ortho_rf=second_stage(fs.hat=fs.hat.rf,X=X,D=D,Y=Y,ss_method=\"ortho_poly\",max_degree=3)\n",
    "### assign outcome = net tfa, treatment = e401 eligibility, covariate x = income\n",
    "#res_ortho_lasso=second_stage(fs.hat=fs.hat.lasso,X=X,D=D,Y=Y,ss_method=\"splines\",max_degree=3)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "regional-remark",
   "metadata": {
    "papermill": {
     "duration": 0.02735,
     "end_time": "2021-04-05T23:20:24.741505",
     "exception": false,
     "start_time": "2021-04-05T23:20:24.714155",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "#plot findings:\n",
    "\n",
    "-- black solid line shows estimated function $p(x)' \\widehat{\\beta}$\n",
    "\n",
    "-- blue dashed lines show pointwise confidence bands for this function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "reduced-south",
   "metadata": {
    "papermill": {
     "duration": 0.640067,
     "end_time": "2021-04-05T23:20:25.409141",
     "exception": false,
     "start_time": "2021-04-05T23:20:24.769074",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "p<-make_plot(res_ortho_rf,degree=res_ortho_rf$degree,ss_method=\"ortho_poly\",uniform=FALSE, lowy=-10,highy=20)\n",
    "options(repr.plot.width=15, repr.plot.height=10)\n",
    "print(p)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "global-column",
   "metadata": {
    "papermill": {
     "duration": 0.031587,
     "end_time": "2021-04-05T23:20:25.473821",
     "exception": false,
     "start_time": "2021-04-05T23:20:25.442234",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "plot findings:\n",
    "\n",
    "-- black solid line shows estimated function $p(x)' \\widehat{\\beta}$\n",
    "\n",
    "-- blue dashed lines show pointwise confidence bands for this function. I.e., for each fixed point $x_0$, i.e., $x_0=1$, they cover $p(x_0)'\\beta_0$ with probability 0.95\n",
    "\n",
    "-- blue solid lines show  uniform confidence bands for this function. I.e.,  they cover the whole function $x \\rightarrow p(x)'\\beta_0$ with probability 0.95 on some compact range"
   ]
  },
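  {
   "cell_type": "markdown",
   "id": "bands-sketch",
   "metadata": {},
   "source": [
    "The two kinds of bands differ only in the critical value. The following is a sketch of the construction as implemented in ```second_stage``` and ```tboot```; the symbols $\\widehat{\\Omega}$, $\\widehat{\\sigma}(x)$ and $\\widehat{c}_{1-\\alpha}$ are notation introduced here for illustration. Let $\\widehat{\\Omega}$ be the heteroskedasticity-robust covariance estimate of $\\widehat{\\beta}$ and\n",
    "\n",
    "$$ \\widehat{\\sigma}(x) = \\sqrt{ p(x)' \\widehat{\\Omega} p(x) }. $$\n",
    "\n",
    "The pointwise band is\n",
    "\n",
    "$$ \\widehat{g}(x) \\pm z_{1-\\alpha/2} \\, \\widehat{\\sigma}(x), $$\n",
    "\n",
    "where $z_{1-\\alpha/2}$ is the standard normal quantile. The uniform band is\n",
    "\n",
    "$$ \\widehat{g}(x) \\pm \\widehat{c}_{1-\\alpha} \\, \\widehat{\\sigma}(x), $$\n",
    "\n",
    "where $\\widehat{c}_{1-\\alpha}$ is the $(1-\\alpha)$-quantile of\n",
    "\n",
    "$$ \\max_{x \\in \\text{grid}} \\left| \\dfrac{ p(x)' \\widehat{\\Omega}^{1/2} N }{ \\widehat{\\sigma}(x) } \\right|, \\qquad N \\sim N(0, I), $$\n",
    "\n",
    "approximated by simulation. For each fixed $x$ the ratio inside the maximum is standard normal, so the maximum over the grid is at least as large as the absolute value at any single point. Hence the exact quantile satisfies $\\widehat{c}_{1-\\alpha} \\geq z_{1-\\alpha/2}$, and the uniform band is at least as wide as the pointwise band."
   ]
  },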
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "broke-trunk",
   "metadata": {
    "papermill": {
     "duration": 0.475554,
     "end_time": "2021-04-05T23:20:25.980854",
     "exception": false,
     "start_time": "2021-04-05T23:20:25.505300",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "p<-make_plot(res_ortho_rf,degree=res_ortho_rf$degree,ss_method=\"ortho_poly\",uniform=TRUE,lowy=-10,highy=25) \n",
    "options(repr.plot.width=15, repr.plot.height=10)\n",
    "print(p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "structured-puppy",
   "metadata": {
    "papermill": {
     "duration": 0.038791,
     "end_time": "2021-04-05T23:20:26.056830",
     "exception": false,
     "start_time": "2021-04-05T23:20:26.018039",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "R",
   "language": "R",
   "name": "ir"
  },
  "language_info": {
   "codemirror_mode": "r",
   "file_extension": ".r",
   "mimetype": "text/x-r-source",
   "name": "R",
   "pygments_lexer": "r",
   "version": "3.6.3"
  },
  "papermill": {
   "default_parameters": {},
   "duration": 178.921245,
   "end_time": "2021-04-05T23:20:26.203059",
   "environment_variables": {},
   "exception": null,
   "input_path": "__notebook__.ipynb",
   "output_path": "__notebook__.ipynb",
   "parameters": {},
   "start_time": "2021-04-05T23:17:27.281814",
   "version": "2.3.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
